import cv2 
import numpy as np

def _as_single_plane(source):
    '''Return one 2-D plane for ``source`` (a file path or a loaded array).'''
    if isinstance(source, np.ndarray):
        image = source
    else:
        image = cv2.imread(source, cv2.IMREAD_UNCHANGED)
        if image is None:
            # cv2.imread does not raise on a missing/undecodable file; it
            # returns None, which would otherwise fail later with an
            # unrelated-looking AttributeError.
            raise IOError('could not read image: {!r}'.format(source))

    if image.ndim == 2:
        # IMREAD_UNCHANGED keeps grayscale files 2-D, so there is no channel
        # axis to index.
        return image
    if image.ndim == 3:
        # Multi-channel input: only the first channel is used.
        return image[:, :, 0]
    raise ValueError(
        'expected a 2-D or 3-D image, got {} dimensions'.format(image.ndim))


def polarizations_fusion(img_list, channel_list):
    '''
    Fuse several single-polarization images into one 3-channel BGR image.

    The i-th image is written into output channel ``channel_list[i]``.
    Channel indices follow OpenCV's BGR order:
    2: red, 1: green, 0: blue

    Parameters
    ----------
    img_list : sequence of str or numpy.ndarray
        Image file paths (read with ``cv2.imread(..., cv2.IMREAD_UNCHANGED)``)
        or already-loaded images. For images with a channel axis only the
        first channel is used.
    channel_list : sequence of int
        Target channel index for each image. Extra entries beyond
        ``len(img_list)`` are ignored.

    Returns
    -------
    numpy.ndarray
        ``uint8`` array of shape ``(H, W, 3)`` where ``(H, W)`` is the size of
        the first image. Channels that are never assigned stay zero.

    Raises
    ------
    ValueError
        If ``img_list`` is empty, if ``channel_list`` has fewer entries than
        ``img_list``, or if the images do not all share the same height and
        width.
    IOError
        If a path in ``img_list`` cannot be read as an image.
    '''
    sources = list(img_list)
    channels = list(channel_list)

    if len(sources) == 0:
        raise ValueError('img_list must contain at least one image')
    if len(channels) < len(sources):
        raise ValueError(
            'channel_list has {} entries but img_list has {}'.format(
                len(channels), len(sources)))

    planes = [_as_single_plane(source) for source in sources]

    height, width = planes[0].shape
    bgr_img = np.zeros((height, width, 3), dtype=np.uint8)

    for plane, channel in zip(planes, channels):
        if plane.shape != (height, width):
            raise ValueError(
                'image size {} does not match first image size {}'.format(
                    plane.shape, (height, width)))
        bgr_img[:, :, channel] = plane

    return bgr_img

if __name__ == '__main__':
    # Demo run: fuse two images from hard-coded dataset paths, show the
    # result on screen, then write it next to the second input.
    import matplotlib.pyplot as plt

    source_paths = [
        '/media/gejunyao/Disk1/Datasets/LS-SSDD-v1.0-OPEN/VOC2012/JPEGImages_full/04.jpg',
        '/media/gejunyao/Disk1/Datasets/LS-SSDD-v1.0-OPEN/VOC2012/JPEGImages_VH/04-VH.jpg',
    ]
    # First image -> channel 2 (red), second image -> channel 0 (blue).
    target_channels = [2, 0]
    fused = polarizations_fusion(source_paths, target_channels)

    # The fused array is in BGR order; swap to RGB for display only.
    rgb_view = fused[:, :, [2, 1, 0]]
    plt.figure()
    plt.imshow(rgb_view)
    plt.show()

    # The file on disk keeps the BGR array as returned by the fusion step.
    cv2.imwrite('/media/gejunyao/Disk1/Datasets/LS-SSDD-v1.0-OPEN/VOC2012/JPEGImages_VH/04-rb.jpg', fused)

    print('hello')